import json
from collections import defaultdict
import os
import numpy as np
from tqdm import tqdm
from azfuse import File

from multiprocessing import Pool, cpu_count


def dedup_by_iou(annotation_file="data/vqav2/vqa_k_test_noun.jsonl", image_folder="data/vqav2/images/", iou_threshold=0.7):
    """Drop questions whose segmentation mask overlaps a kept question on the same image.

    Questions are grouped by ``file_id``. Images with a single question keep it
    unconditionally. For images with several questions, each question's mask is
    loaded from ``<image_folder>/remove_anything/gsam_masks/<file_id>- <question_id>_mask.npy``
    (note the space after "-"). Questions are visited in file order, and a question
    is kept only if its IoU with every previously *kept* question of that image is
    below ``iou_threshold``. On multi-question images, questions whose mask file
    does not exist are dropped.

    The result is written as JSONL next to ``annotation_file`` with ``.jsonl``
    replaced by ``_dedup.jsonl``.

    Example input record::

        {"question": "What is this man taking a picture of?", "question_id": 1655007,
         "labels": {"mountain": 0.6, "mountains": 1, "scenery": 0.3}, "image_id": 1655,
         "answer_type": "other", "question_type": "what is this",
         "file_id": "COCO_val2014_000000001655", "noun_ans": "mountains",
         "gpt_noun_ans": "mountains"}
    """
    with File.open(annotation_file, "r") as f:
        data = [json.loads(line) for line in f if line.strip()]
    output_file = os.path.join(
        os.path.dirname(annotation_file),
        os.path.basename(annotation_file).replace('.jsonl', '_dedup.jsonl'),
    )
    seg_folder = os.path.join(image_folder, "remove_anything/gsam_masks")

    file_id2q = defaultdict(list)
    for d in data:
        file_id2q[d["file_id"]].append(d)

    output = []
    for file_id, questions in tqdm(file_id2q.items()):
        if len(questions) == 1:
            output.append(questions[0])
            continue
        kept_segs = []
        for q in questions:
            seg_path = os.path.join(seg_folder, f"{file_id}- {q['question_id']}_mask.npy")
            if not os.path.exists(seg_path):
                continue
            seg = np.load(seg_path)
            # Compare only against questions that were kept, so a question is not
            # dropped for resembling a question that was itself dropped.
            # `not any(iou >= t)` (instead of `all(iou < t)`) treats a NaN IoU
            # (e.g. two empty masks) as "not a duplicate".
            if not any(compute_iou(seg, kept) >= iou_threshold for kept in kept_segs):
                kept_segs.append(seg)
                output.append(q)

    with File.open(output_file, "w") as f:
        for d in output:
            f.write(json.dumps(d) + "\n")


def compute_iou(seg1, seg2):
    """Return intersection-over-union of two boolean-like mask arrays.

    When the arrays differ in their first dimension, each is first collapsed with
    a logical OR along axis 0 before comparing. The result is
    ``sum(seg1 & seg2) / sum(seg1 | seg2)``. If both masks are empty this is a
    0/0 division, which yields NaN under numpy rules.
    """
    if seg1.shape[0] != seg2.shape[0]:
        seg1, seg2 = (np.any(mask, axis=0) for mask in (seg1, seg2))
    overlap = np.sum(np.logical_and(seg1, seg2))
    combined = np.sum(np.logical_or(seg1, seg2))
    return overlap / combined



def sample_one_per_image(dedup_jsonl="data/vqav2/vqa_k_test_noun_dedup.jsonl"):
    """Keep only the first question of every image (``file_id``) in a JSONL file.

    The result is written to ``dedup_jsonl`` with ``.jsonl`` replaced by
    ``_sampled_1.jsonl``. Images appear in order of their first occurrence.
    Blank input lines are ignored.
    """
    output_jsonl = dedup_jsonl.replace(".jsonl", "_sampled_1.jsonl")
    first_per_image = {}
    # Read through a context manager so the input handle is closed.
    with File.open(dedup_jsonl, "r") as f:
        for line in f:
            if not line.strip():
                continue
            d = json.loads(line)
            # setdefault keeps the first record seen for each file_id.
            first_per_image.setdefault(d["file_id"], d)
    with File.open(output_jsonl, "w") as f:
        for d in tqdm(first_per_image.values()):
            f.write(json.dumps(d) + "\n")


import json


def convert_perturb_q_to_json(input_str):
    """Parse a generated question/answer block into a dict.

    Expected input, one item per line, each optionally prefixed with a "-" bullet::

        - Q: <question>
        - A1: <answer for the original image>
        - A2: <answer for the perturbed image>

    Returns ``{"Q": ..., "A_orig": ..., "A_perturb": ...}``; items that are absent
    stay "". Blank lines are ignored.

    Raises:
        ValueError: on any other non-empty line.
    """
    data_dict = {"Q": "", "A_orig": "", "A_perturb": ""}
    # Line prefix -> output key it fills.
    prefixes = (("Q:", "Q"), ("A1:", "A_orig"), ("A2:", "A_perturb"))
    for raw_line in input_str.split("\n"):
        # Strip only the leading "-" bullet. Hyphens inside the text
        # (e.g. "self-driving") must survive.
        line = raw_line.strip().lstrip("-").strip()
        if not line:
            continue
        for prefix, key in prefixes:
            if line.startswith(prefix):
                data_dict[key] = line[len(prefix):].strip()
                break
        else:
            raise ValueError(f"Unrecognized line '{line}'")
    return data_dict


def convert_caption_q_to_json(input_str):
    """Parse a generated two-question block into a dict.

    Expected input, one item per line, each optionally prefixed with a "-" bullet::

        - Q1: <first question>
        - A1: <answer to Q1>
        - Q2: <second question>
        - A2: <answer to Q2>

    Returns ``{"Q1": ..., "A1": ..., "Q2": ..., "A2": ...}``; items that are absent
    stay "". Blank lines are ignored.

    Raises:
        ValueError: on any other non-empty line.
    """
    keys = ("Q1", "A1", "Q2", "A2")
    data_dict = {key: "" for key in keys}
    for raw_line in input_str.split("\n"):
        # Strip only the leading "-" bullet. Hyphens inside the text must survive.
        line = raw_line.strip().lstrip("-").strip()
        if not line:
            continue
        for key in keys:
            prefix = key + ":"
            if line.startswith(prefix):
                # Slice by the full prefix length ("Q1:" is 3 chars) so no ":" leaks into the value.
                data_dict[key] = line[len(prefix):].strip()
                break
        else:
            raise ValueError(f"Unrecognized line '{line}'")
    return data_dict


def process_annotation_perturb_q(annotation, input_folder, include_orig):
    """Build LLaVA-style conversation records for one perturbed-question annotation.

    Three files are expected in ``input_folder``:
      * ``<file_id>.jpg``, the original image;
      * ``<file_id>-<question_id>_remove_0.png``, the perturbed image;
      * the matching ``_remove_0.txt``, which holds the generated Q/A1/A2 text.

    Returns None if any of the three is missing. Otherwise returns a list holding
    the original-image record (only when ``include_orig``) followed by the
    perturbed-image record. Both records share the question and answer with A1 and
    A2 respectively. Any exception is printed and turned into a None result.
    """
    try:
        image_path = os.path.join(input_folder, annotation['file_id'] + '.jpg')
        mask_image_path = os.path.join(
            input_folder, f"{annotation['file_id']}-{annotation['question_id']}_remove_0.png"
        )
        qa_text_path = mask_image_path.replace(".png", ".txt")

        if not all(File.isfile(p) for p in (image_path, mask_image_path, qa_text_path)):
            return None

        with File.open(qa_text_path, "r") as f:
            qa = convert_perturb_q_to_json(f.read().strip())

        def make_record(path, extension, answer):
            # Record id is the file name without its extension.
            file_name = os.path.basename(path)
            return {
                "id": file_name.replace(extension, ""),
                "image": file_name,
                "conversations": [
                    {"from": "human", "value": "<image>\n" + qa["Q"]},
                    {"from": "gpt", "value": answer},
                ],
            }

        records = []
        if include_orig:
            records.append(make_record(image_path, ".jpg", qa["A_orig"]))
        records.append(make_record(mask_image_path, ".png", qa["A_perturb"]))
        return records
    except Exception as e:
        print(f"Error processing annotation {annotation['file_id']}: {e}")
        return None


def process_annotation_docci_caption_q(annotation, input_folder, image_sub_folder, question_type):
    """Build an (answerable, unanswerable) pair of LLaVA records for one DOCCI image.

    Inputs read:
      * the image ``<input_folder>/<image_sub_folder>/<example_id>.jpg``;
      * the generated Q1/A1/Q2/A2 text
        ``<input_folder>/docci/gpt4_gen_<question_type>_q/<example_id>.txt``.

    An answer counts as unanswerable when it contains "unanswerable"
    (case-insensitive). Record ids are
    ``<example_id>_<question_type>_{answerable|unanswerable}``.

    Returns None (printing a reason where the original did) when any of these hold:
      * a file is missing;
      * either answer contains "not a good question" or "caption";
      * the two answers are not exactly one unanswerable and one answerable;
      * any error occurs.
    """
    try:
        example_id = annotation['example_id']
        question_sub_folder = f"docci/gpt4_gen_{question_type}_q"
        image_rel_path = os.path.join(image_sub_folder, example_id + '.jpg')
        image_path = os.path.join(input_folder, image_rel_path)
        q_path = os.path.join(input_folder, question_sub_folder, f"{example_id}.txt")

        if not File.isfile(image_path):
            print(f"File not found: {image_path}")
            return None
        if not File.isfile(q_path):
            print(f"File not found: {q_path}")
            return None

        with File.open(q_path, "r") as f:
            q_json = convert_caption_q_to_json(f.read().strip())
        id_ = os.path.basename(image_path).replace(".jpg", "")

        records = []
        unanswerable_flags = []
        for q_key, a_key in (("Q1", "A1"), ("Q2", "A2")):
            question, answer = q_json[q_key], q_json[a_key]
            answer_lower = answer.lower()
            # Reject generations that rejected the question or refer to the caption.
            if "not a good question" in answer_lower or "caption" in answer_lower:
                return None
            is_unanswerable = "unanswerable" in answer_lower
            unanswerable_flags.append(is_unanswerable)
            label = "unanswerable" if is_unanswerable else "answerable"
            records.append({
                "id": f"{id_}_{question_type}_{label}",
                "image": image_rel_path,
                "conversations": [
                    {"from": "human", "value": "<image>\n" + question},
                    {"from": "gpt", "value": answer},
                ],
            })

        # Require exactly one unanswerable question. Compare the flags directly:
        # "unanswerable".endswith("answerable") is True, so a suffix test cannot
        # tell the two labels apart.
        if unanswerable_flags[0] == unanswerable_flags[1]:
            print(example_id, q_json)
            return None
        return records
    except Exception as e:
        print(f"Error processing annotation {annotation['example_id']}: {e}")
        return None


def split_docci_into_train_test_dev(input_file):
    """Split a JSON list of records into train/dev/test files by substring of ``id``.

    A record goes to the first split among "train", "dev" and "test" whose name
    occurs in its ``id``. Ids that match none are printed. Each split is written
    as ``input_file`` with ``.json`` replaced by ``_<split>.json``, indented by 4.
    """
    with File.open(input_file, "r") as f:
        records = json.load(f)

    # Insertion order fixes both the matching priority and the write order.
    buckets = {"train": [], "dev": [], "test": []}
    for record in records:
        split_name = next((name for name in buckets if name in record["id"]), None)
        if split_name is None:
            print(record["id"])
        else:
            buckets[split_name].append(record)

    for split_name, items in buckets.items():
        with File.open(input_file.replace(".json", f"_{split_name}.json"), "w") as f:
            json.dump(items, f, indent=4)


def convert_to_llava_format(output_file, annotation_file="data/vqav2/vqa_k_test_noun_dedup_sampled_1.jsonl", input_folder="<DATA_FOLDER>/vqav2/remove_anything/lama-gpt4v_gen_q/", debug=False, include_orig=True):
    """Convert perturbed-question annotations into a LLaVA training JSON file.

    Every line of ``annotation_file`` is processed in a process pool by
    ``process_annotation_perturb_q``. Annotations whose files are missing or fail
    to parse yield None and are dropped. The output is a JSON list (indent 4) of
    records shaped like::

        {
            "id": "000000157875",
            "image": "000000157875.jpg",
            "conversations": [
                {"from": "human", "value": "<image>\\nWhat activity could ...?"},
                {"from": "gpt", "value": "Flying a kite ..."}
            ]
        }

    Args:
        output_file: destination JSON path.
        annotation_file: JSONL with ``file_id`` and ``question_id`` on each line.
        input_folder: folder holding the original ``.jpg``, the perturbed
            ``_remove_0.png`` and the generated ``_remove_0.txt`` files.
        debug: if True, only the first 100 annotations are processed.
        include_orig: also emit the record for the original (unperturbed) image.
    """
    with File.open(annotation_file, "r") as f:
        annotations = [json.loads(line) for line in f]

    if debug:
        annotations = annotations[:100]  # For debugging, process a subset

    mask_image_files = [
        os.path.join(input_folder, f"{annotation['file_id']}-{annotation['question_id']}_remove_0.png")
        for annotation in annotations
    ]
    gpt4v_gen = [im_name.replace(".png", ".txt") for im_name in mask_image_files]
    # Hand the existing generated-text files to File.prepare (azfuse) before the
    # workers open them. Presumably this pre-fetches/caches them — confirm.
    File.prepare([f for f in gpt4v_gen if File.isfile(f)])

    print(f"Total number of annotations: {len(annotations)}")

    with Pool(cpu_count()) as pool:
        results = pool.starmap(
            process_annotation_perturb_q,
            [(annotation, input_folder, include_orig) for annotation in annotations],
        )

    # Flatten the per-annotation lists and drop None (skipped) results.
    output = [item for sublist in results if sublist is not None for item in sublist]

    print(f"Total number of processed annotations: {len(output)}")

    with File.open(output_file, "w") as f:
        json.dump(output, f, indent=4)


def convert_docci_caption_based_to_llava_format(output_file, annotation_file="<DATA_FOLDER>/docci/docci_descriptions.jsonlines", input_folder="<DATA_FOLDER>/", image_sub_folder="docci/images", question_type="complex",  debug=False):
    """Convert DOCCI caption-based generated questions into a LLaVA training JSON file.

    Each DOCCI annotation is turned into a pair of records by
    ``process_annotation_docci_caption_q``, run in a process pool. Annotations that
    return None are skipped. The flattened records are dumped with indent 4.
    ``debug`` limits processing to the first 100 annotations.
    """
    with File.open(annotation_file, "r") as f:
        annotations = [json.loads(line) for line in f]

    annotations = annotations[:100] if debug else annotations
    print(f"Total number of annotations: {len(annotations)}")

    job_args = [(ann, input_folder, image_sub_folder, question_type) for ann in annotations]
    with Pool(cpu_count()) as pool:
        per_annotation = pool.starmap(process_annotation_docci_caption_q, job_args)

    output = []
    for records in per_annotation:
        if records is not None:
            output.extend(records)

    print(f"Total number of processed annotations: {len(output)}")

    with File.open(output_file, "w") as f:
        json.dump(output, f, indent=4)


def convert_llava_to_jsonl(input_file, output_file, random_perturb=False):
    """Flatten LLaVA conversation records into evaluation JSONL rows.

    Each output row contains:
      * ``question_id``: the record id;
      * ``image``;
      * ``text``: element 1 of the first turn's value split on newlines;
      * ``answer``: the second turn's value;
      * ``category``: "unk" when the image path contains "remove", else "know".

    With ``random_perturb``, a "know" record that carries a non-empty
    ``perturb_image`` list gets its image replaced by a random element of that list.
    """
    with File.open(input_file, "r") as f:
        records = json.load(f)

    rows = []
    for record in records:
        turns = record["conversations"]
        category = "unk" if "remove" in record["image"] else "know"
        image = record["image"]
        candidates = record.get("perturb_image", [])
        swap_image = random_perturb and len(candidates) > 0 and category == "know"
        if swap_image:
            image = random.choice(candidates)
            print(f"Randomly select image from : {candidates}")
        rows.append({
            "question_id": record["id"],
            "image": image,
            "text": turns[0]["value"].split("\n")[1],
            "answer": turns[1]["value"],
            "category": category,
        })

    with File.open(output_file, "w") as f:
        for row in rows:
            f.write(json.dumps(row) + "\n")


def convert_docci_to_eval_jsonl(input_file, output_file):
    """Flatten DOCCI LLaVA-format records into evaluation JSONL rows.

    ``category`` is "unk" when the record id contains "unanswerable", otherwise
    "know". ``text`` is element 1 of the first turn's value split on newlines, and
    ``answer`` is the second turn's value.
    """
    with File.open(input_file, "r") as f:
        records = json.load(f)

    def to_row(record):
        turns = record["conversations"]
        return {
            "question_id": record["id"],
            "image": record["image"],
            "text": turns[0]["value"].split("\n")[1],
            "answer": turns[1]["value"],
            "category": "unk" if "unanswerable" in record["id"] else "know",
        }

    rows = [to_row(record) for record in records]
    with File.open(output_file, "w") as f:
        f.write("".join(json.dumps(row) + "\n" for row in rows))


def random_shuffle_and_sample_5k_for_docci():
    """Sample up to 2,500 question pairs (5k questions) from each DOCCI test eval file.

    For each question type this reads ``docci_<type>_test.eval.jsonl``, whose
    records are treated as consecutive pairs (presumably the answerable and
    unanswerable question of one image — confirm). If there are more than 2,500
    pairs, 2,500 are picked at random; otherwise all pairs are kept in file order.
    The selection is written to ``docci_<type>_test.5k.eval.jsonl``.
    """
    input_folder = "<DATA_FOLDER>/docci/"
    max_pairs = 2500
    for question_type in ["complex", "know", "pred", "ambiguity"]:
        q_path = os.path.join(input_folder, f"docci_{question_type}_test.eval.jsonl")
        with File.open(q_path, "r") as f:
            data = [json.loads(line.strip()) for line in f]
        # Start index of every complete pair. A trailing unpaired record is
        # ignored instead of causing an IndexError when it happens to be sampled.
        paired_index = list(range(0, len(data) - 1, 2))
        if len(paired_index) > max_pairs:
            random.shuffle(paired_index)
            paired_index = paired_index[:max_pairs]
        sampled_data = []
        for i in paired_index:
            sampled_data.append(data[i])
            sampled_data.append(data[i + 1])
        with File.open(os.path.join(input_folder, f"docci_{question_type}_test.5k.eval.jsonl"), "w") as f:
            for d in sampled_data:
                f.write(json.dumps(d) + "\n")


def train_split_into_5k_for_docci():
    """Split each DOCCI train eval JSONL into shards of at most 5,000 records.

    Reads ``docci<suffix>_train.eval.jsonl`` for the suffixes "_complex", "_know",
    "_pred", "_ambiguity" and "". Writes ``docci<suffix>_train.<k>.eval.jsonl``
    for k = 0, 1, ...; an empty input produces no shards.
    """
    input_folder = "<DATA_FOLDER>/docci/"
    shard_size = 5000
    for question_type in ["_complex", "_know", "_pred", "_ambiguity", ""]:
        q_path = os.path.join(input_folder, f"docci{question_type}_train.eval.jsonl")
        with File.open(q_path, "r") as f:
            data = [json.loads(line.strip()) for line in f]
        for s, start_index in enumerate(range(0, len(data), shard_size)):
            shard = data[start_index:start_index + shard_size]
            with File.open(os.path.join(input_folder, f"docci{question_type}_train.{s}.eval.jsonl"), "w") as f:
                for d in shard:
                    f.write(json.dumps(d) + "\n")


def train_merge_docci():
    """Merge DOCCI caption-based training data with other instruction-tuning sets.

    Steps:
      1. Concatenate ``docci_{complex,know,pred,ambiguity}_train.json`` and write
         ``<DATA_FOLDER>/docci/ours_caption_based.json``.
      2. Append ``unk_v1+gqa.json`` and the LLaVA-1.5 665k mix, rewriting
         "VG_100K_2" to "VG_100K" in image paths that contain "vg". Write the
         result to ``<DATA_FOLDER>ours+llava.json``.
      3. Convert that file to Qwen-VL format with ``convert_llava_to_qwen_vl``.
    """
    input_folder = "<DATA_FOLDER>/docci/"
    data = []
    for question_type in ["complex", "know", "pred", "ambiguity"]:
        q_path = os.path.join(input_folder, f"docci_{question_type}_train.json")
        with File.open(q_path, "r") as f:
            data.extend(json.load(f))
    with File.open(os.path.join(input_folder, "ours_caption_based.json"), "w") as f:
        json.dump(data, f)

    # Merge DOCCI with the idk (perturbed-image) data.
    perturb_idk = "<DATA_FOLDER>unk_v1+gqa.json"
    with File.open(perturb_idk, "r") as f:
        data.extend(json.load(f))
    # NOTE(review): an earlier version noted 153260 here, presumably the record count.

    llava_data_path = "<DATA_FOLDER>/llava_v1_5_mix665k.json"
    with File.open(llava_data_path, "r") as f:
        llava_data = json.load(f)
    for d in tqdm(llava_data, desc="processing llava data"):
        if "image" in d and "vg" in d["image"]:
            d["image"] = d["image"].replace("VG_100K_2", "VG_100K")
        data.append(d)

    merged_path = "<DATA_FOLDER>ours+llava.json"
    with File.open(merged_path, "w") as f:
        json.dump(data, f)
    # NOTE(review): an earlier version noted 818558 here, presumably the record count.
    convert_llava_to_qwen_vl(
        llava_json=merged_path,
        output_file="<DATA_FOLDER>unk_v1+gqa+ours_caption_based+llava_v1_5_mix665k_for_qwenvl.json",
        image_folder="<DATA_FOLDER>"
    )
    # NOTE(review): an earlier version noted 683694 here, presumably the converted record count.
    # 683694


def convert_docci_to_dpo_format():
    """Build DPO preference samples from DOCCI caption-based questions and merge other DPO sets.

    ``<DATA_FOLDER>/docci/ours_caption_based.eval.jsonl`` is read as consecutive
    record pairs, and the two records of a pair must share an image (asserted).
    The first record of a pair is treated as the unanswerable one if its
    ``question_id`` contains "unanswerable"; otherwise the second is. Every pair
    yields two samples, both using the unanswerable record's question (prefixed
    with "<image>\\n" when it lacks "<image>"):

      * ``hadpo_docci_<answerable id>``: chosen = answerable answer,
        rejected = unanswerable answer;
      * ``hadpo_docci_<unanswerable id>``: chosen = unanswerable answer,
        rejected = answerable answer.

    Files written under ``<DATA_FOLDER>/hadpo-data/hadpo/llava-v1.5``:
      * ``ours_caption_based.dpo.json``: the DOCCI samples;
      * ``ours_image_based+ours_caption_based.dpo.json``: the DOCCI samples plus
        ``ours_image_based.json``;
      * the same combination plus ``pope+desc_data.json``, and plus
        ``silkie.json``, each in its own file.
    """
    input_folder = "<DATA_FOLDER>/docci/"
    dpo_folder = "<DATA_FOLDER>/hadpo-data/hadpo/llava-v1.5"

    with File.open(os.path.join(input_folder, "ours_caption_based.eval.jsonl")) as f:
        data = [json.loads(line) for line in f]

    def make_sample(sample_id, image, question, chosen, rejected):
        return {
            "id": sample_id,
            "image": image,
            "chosen_conversations": [
                {"from": "human", "value": question},
                {"from": "gpt", "value": chosen},
            ],
            "reject_conversations": [
                {"from": "human", "value": question},
                {"from": "gpt", "value": rejected},
            ],
        }

    samples = []
    for idx in range(0, len(data), 2):
        first, second = data[idx], data[idx + 1]
        assert first["image"] == second["image"]
        if "unanswerable" in first["question_id"]:
            q_idk, q_k = first, second
        else:
            q_idk, q_k = second, first
        question = q_idk["text"]
        if "<image>" not in question:
            question = "<image>\n" + question
        samples.append(make_sample(f"hadpo_docci_{q_k['question_id']}", q_k["image"], question, q_k["answer"], q_idk["answer"]))
        samples.append(make_sample(f"hadpo_docci_{q_idk['question_id']}", q_k["image"], question, q_idk["answer"], q_k["answer"]))

    with File.open(f"{dpo_folder}/ours_caption_based.dpo.json", "w") as f:
        json.dump(samples, f)

    # From here on, `samples` holds DOCCI + image-based samples and is the base of the later merges.
    with File.open(f"{dpo_folder}/ours_image_based.json", "r") as f:
        samples.extend(json.load(f))
    with File.open(f"{dpo_folder}/ours_image_based+ours_caption_based.dpo.json", "w") as f:
        json.dump(samples, f)

    extra_sets = (
        ("pope+desc_data.json", "ours_image_based+ours_caption_based+pope+desc_data.dpo.json"),
        ("silkie.json", "ours_image_based+ours_caption_based+silkie.dpo.json"),
    )
    for extra_name, out_name in extra_sets:
        with File.open(f"{dpo_folder}/{extra_name}", "r") as f:
            extra = json.load(f)
        with File.open(f"{dpo_folder}/{out_name}", "w") as f:
            json.dump(samples + extra, f)


def find_answerable_image_by_iou(annotation_file, image_folder="data/vqav2/images/", iou_threshold=0.5, debug=False):
    """Single-process counterpart of ``find_answerable_image_by_iou_multiprocess``.

    ``annotation_file`` is a JSON list of (original, perturbed) records stored in
    consecutive pairs. ``process_chunk`` fills each original record's
    ``perturb_image`` list with perturbed images of the same source image whose
    mask IoU with this question's mask is below ``iou_threshold``. The output is
    written next to the input with ``.json`` replaced by
    ``.rand_perturb_answerable.json``. ``debug`` keeps only the first 100 records.
    """
    with File.open(annotation_file, "r") as f:
        data = json.load(f)
    if debug:
        data = data[:100]
    output_file = os.path.join(os.path.dirname(annotation_file), f"{os.path.basename(annotation_file).replace('.json', '.rand_perturb_answerable.json')}")
    seg_folder = os.path.join(image_folder, "remove_anything/gsam_masks")
    # process_chunk checks for "<file_id>-<qid>_remove_0.png" in the folder it is
    # given. Those perturbed images live under remove_anything/lama, matching what
    # the multiprocess version passes.
    perturb_image_folder = os.path.join(image_folder, "remove_anything/lama")

    data_pairs = [(data[idx], data[idx + 1]) for idx in range(0, len(data), 2)]
    # A single call: process_chunk lists seg_folder on every call, so calling it
    # once per pair would re-list the folder for every pair.
    output_data = process_chunk(data_pairs, perturb_image_folder, iou_threshold, seg_folder)

    with File.open(output_file, "w") as f:
        json.dump(output_data, f, indent=4)


import json
import os
# import glob
import numpy as np
import random
from multiprocessing import Pool

def process_chunk(data_chunk, image_folder, iou_threshold, seg_folder):
    """Attach up to 5 "answerable" perturbed images to each original record.

    Args:
        data_chunk: list of ``(orig, perturb)`` record pairs. ``perturb["image"]``
            is parsed as ``<file_id>-<qid>_remove_0.png`` and ``orig["image"]``
            as ``<file_id>.jpg``.
        image_folder: folder checked for the ``<file_id>-<qid>_remove_0.png``
            perturbed images.
        iou_threshold: a candidate is accepted when its mask IoU with this
            question's mask is below this value.
        seg_folder: folder holding ``<file_id>- <qid>_mask.npy`` masks (note the space).

    For each pair, the masks of the other questions on the same image are visited
    in random order. A candidate is appended to ``orig["perturb_image"]`` when its
    perturbed image and mask both exist and its IoU is below the threshold;
    collection stops at 5. ``orig`` is modified in place.

    Returns:
        The flat list ``[orig, perturb, orig, perturb, ...]``.
    """
    all_seg_files = [f for f in File.list(seg_folder) if f.endswith("_mask.npy")]
    output_data_chunk = []
    for orig, perturb in tqdm(data_chunk):
        perturb_file_id = perturb["image"].replace(".png", "").replace("_remove_0", "")
        qid = perturb_file_id.split("-")[-1]
        orig_file_id = orig["image"].replace(".jpg", "")
        seg_files = [f for f in all_seg_files if os.path.basename(f).startswith(orig_file_id)]
        # Map question id -> mask path, excluding the current question itself.
        seg_qid = {os.path.basename(f).split("- ")[-1].replace("_mask.npy", ""): f for f in seg_files}
        seg_qid = {id_: v for id_, v in seg_qid.items() if id_ != qid}
        # Use context managers so mask file handles are closed after loading.
        with File.open(os.path.join(seg_folder, f"{orig_file_id}- {qid}_mask.npy"), "rb") as fh:
            gt_seg = np.load(fh)
        candid_qids = list(seg_qid.keys())
        random.shuffle(candid_qids)
        orig["perturb_image"] = []
        for id_ in candid_qids:
            file_path = seg_qid[id_]
            if not File.isfile(os.path.join(image_folder, f"{orig_file_id}-{id_}_remove_0.png")) or not File.isfile(file_path):
                print(f"File not found: {orig_file_id}-{id_}")
                continue
            # NOTE(review): this joins a file path with a file name, so it can only
            # match if File.list returned directory-like entries; it looks like a
            # no-op — confirm against azfuse's File.list semantics.
            if File.isfile(os.path.join(file_path, f"{orig_file_id}- {id_}_mask.npy")):
                file_path = os.path.join(file_path, f"{orig_file_id}- {id_}_mask.npy")
            with File.open(file_path, "rb") as fh:
                seg = np.load(fh)
            iou = compute_iou(gt_seg, seg)
            print(f"iou: {iou}")
            if iou < iou_threshold:
                orig["perturb_image"].append(f"{orig_file_id}-{id_}_remove_0.png")
                print(f"Found answerable image: {orig['perturb_image']}")
                if len(orig["perturb_image"]) >= 5:
                    break

        output_data_chunk.append(orig)
        output_data_chunk.append(perturb)
    return output_data_chunk

def find_answerable_image_by_iou_multiprocess(annotation_file, image_folder="data/vqav2/images", iou_threshold=0.5, num_processes=16, debug=False):
    """Parallel version: find answerable perturbed images for each (original, perturbed) pair.

    ``annotation_file`` is a JSON list of records stored in consecutive
    (original, perturbed) pairs. The pairs are split into at most
    ``num_processes`` contiguous chunks, each handled by ``process_chunk`` in a
    process pool. Results are concatenated in chunk order and written next to the
    input with ``.json`` replaced by ``.rand_perturb_answerable.json``. ``debug``
    keeps only the first 100 records.
    """
    with File.open(annotation_file, "r") as f:
        data = json.load(f)
    if debug:
        data = data[:100]
    output_file = os.path.join(os.path.dirname(annotation_file), f"{os.path.basename(annotation_file).replace('.json', '.rand_perturb_answerable.json')}")

    seg_folder = os.path.join(image_folder, "remove_anything/gsam_masks")
    perturb_image_folder = os.path.join(image_folder, "remove_anything/lama")

    data_pairs = [(data[i], data[i + 1]) for i in range(0, len(data), 2)]
    # Ceil-divide so there are at most num_processes chunks. max(1, ...) keeps the
    # range step non-zero when there are no pairs.
    chunk_size = max(1, -(-len(data_pairs) // num_processes))
    data_chunks = [data_pairs[i:i + chunk_size] for i in range(0, len(data_pairs), chunk_size)]

    with Pool(num_processes) as pool:
        results = pool.starmap(process_chunk, [(chunk, perturb_image_folder, iou_threshold, seg_folder) for chunk in data_chunks])

    output_data = [item for sublist in results for item in sublist]

    with File.open(output_file, "w") as f:
        json.dump(output_data, f, indent=4)


def get_perturbed_image_list(json_file, output_file="perturbed_image_list.txt"):
    """Write every entry of the records' ``perturb_image`` lists to a text file.

    Records without a ``perturb_image`` key are skipped. Entries are written one
    per line, in record order.

    Args:
        json_file: JSON list of records.
        output_file: destination path. It defaults to ``perturbed_image_list.txt``
            (relative to the working directory), the previously hard-coded path.
    """
    with File.open(json_file, "r") as f:
        data = json.load(f)

    all_perturbed_image = []
    for d in tqdm(data):
        if "perturb_image" in d:
            all_perturbed_image.extend(d["perturb_image"])

    with File.open(output_file, "w") as f:
        for img in all_perturbed_image:
            f.write(f"{img}\n")


def _llava_item_to_qwen_conversations(item, image_folder):
    """Convert one LLaVA record's turns to Qwen-VL turns; return None to skip the record.

    "human" turns become "user" turns and all other turns become "assistant"
    turns. When the record has an ``image``, the "<image>" token of a human turn
    is replaced by ``Picture 1: <img>{image_folder/image}</img>\\n`` ("VG_100K_2/"
    is rewritten to "VG_100K/" in that path). The record is skipped when an
    image-bearing human turn or any assistant turn contains "[0." (presumably
    normalized box coordinates, which are not converted — confirm).
    """
    conversation = item["conversations"]
    if "image" not in item:
        return [
            {"from": "user" if conv["from"] == "human" else "assistant", "value": conv["value"]}
            for conv in conversation
        ]

    image = os.path.join(image_folder, item["image"])
    if "VG_100K_2/" in image:
        image = image.replace("VG_100K_2/", "VG_100K/")

    new_convs = []
    for conv in conversation:
        value = conv["value"]
        if conv["from"] == "human":
            if "<image>" in value:
                if "[0." in value:
                    return None
                # Remove "<image>\n" first so no blank line is left after </img>,
                # then any bare "<image>".
                text = value.replace("<image>\n", "").replace("<image>", "")
                value = f"Picture 1: <img>{image}</img>\n{text}"
            new_conv = {"from": "user", "value": value}
        else:
            if "[0." in value:
                return None
            new_conv = {"from": "assistant", "value": value}
        if "<img>" in new_conv["value"]:
            assert "</img>" in new_conv["value"], new_conv["value"]
        assert "<image>" not in new_conv["value"], f"old: {conv['value']}, new: {new_conv['value']}"
        new_convs.append(new_conv)
    return new_convs


def convert_llava_to_qwen_vl(llava_json, output_file, image_folder="<DATA_FOLDER>/lama-gpt4v_gen_q/"):
    r"""Convert a LLaVA-format conversation JSON file into Qwen-VL fine-tuning format.

    The target format (examples translated from the Qwen-VL README)::

        [
            {"id": "identity_0",
             "conversations": [
                {"from": "user", "value": "Hello"},
                {"from": "assistant", "value": "I am Qwen-VL, a large model that supports visual input."}]},
            {"id": "identity_1",
             "conversations": [
                {"from": "user", "value": "Picture 1: <img>https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg</img>\nWhat breed is the dog in the picture?"},
                {"from": "assistant", "value": "The dog in the picture is a Labrador."},
                {"from": "user", "value": "Draw a box around the plaid shirt in the picture"},
                {"from": "assistant", "value": "<ref>plaid shirt</ref><box>(588,499),(725,789)</box>"}]},
            {"id": "identity_2",
             "conversations": [
                {"from": "user", "value": "Picture 1: <img>assets/mm_tutorial/Chongqing.jpeg</img>\nPicture 2: <img>assets/mm_tutorial/Beijing.jpeg</img>\nWhere are these pictures taken?"},
                {"from": "assistant", "value": "The first picture is the Chongqing skyline and the second is the Beijing skyline."}]}
        ]

    Image paths are joined onto ``image_folder``. Records whose turns look like
    they carry box coordinates (contain "[0.") are dropped. Counts before and
    after conversion are printed.
    """
    with File.open(llava_json, "r") as f:
        llava_data = json.load(f)
    print("Len: ", len(llava_data))

    output = []
    for item in tqdm(llava_data):
        new_convs = _llava_item_to_qwen_conversations(item, image_folder)
        if new_convs is not None:
            output.append({"id": item["id"], "conversations": new_convs})
    print("Len: ", len(output))

    with File.open(output_file, "w") as f:
        json.dump(output, f)


def shuffle_llava_data(llava_json):
    """Shuffle a LLaVA-format JSON list once, then split the shuffled copy.

    The shuffled copy is written next to ``llava_json`` with ``.json``
    replaced by ``.shuffle.json``. If that file already exists it is reused
    as-is, so reruns keep the same order. In both cases the shuffled file is
    then handed to ``split_llava_shuffle_json``.

    Args:
        llava_json: path to a JSON file holding a list of samples.
    """
    # `random` is not imported at module level (other converters in this file
    # import it locally), so import it here to avoid a NameError.
    import random

    shuffle_output = llava_json.replace(".json", ".shuffle.json")
    if not File.isfile(shuffle_output):
        with File.open(llava_json, "r") as f:
            llava_data = json.load(f)
        print(f"Len: {len(llava_data)}")
        index = list(range(len(llava_data)))
        random.shuffle(index)
        llava_data = [llava_data[i] for i in index]

        with File.open(shuffle_output, "w") as f:
            json.dump(llava_data, f)

    split_llava_shuffle_json(shuffle_output)


def _is_short_image_qa(sample):
    """Return True for a single-turn image QA sample with a short answer.

    A sample is rejected when it:
      * has more than two conversation turns;
      * has no ``"image"`` key;
      * has a human turn containing ``"this region:"``; or
      * has a non-human turn with more than five space-separated words, or
        starting with ``"["``. Presumably this is a coordinate/box answer;
        confirm against the data.
    """
    if len(sample["conversations"]) > 2 or "image" not in sample:
        return False
    for conv in sample["conversations"]:
        value = conv["value"]
        if conv["from"] == "human":
            if "this region:" in value:
                return False
        elif len(value.split(" ")) > 5 or value.startswith("["):
            return False
    return True


def split_llava_shuffle_json(llava_json):
    """Keep short single-turn image QA samples and write them in 5000-item chunks.

    Samples are filtered with ``_is_short_image_qa``. Chunk ``k`` is written
    to ``llava_json`` with ``.json`` replaced by ``.short.5k.<k>.json``. It is
    then converted with ``convert_to_llava_eval_format``.

    Args:
        llava_json: path to a JSON file holding a list of samples.
    """
    with File.open(llava_json, "r") as f:
        llava_data = json.load(f)

    # only keep the ones with short answers
    short_llava_data = [d for d in tqdm(llava_data) if _is_short_image_qa(d)]
    print(len(short_llava_data))

    split_size = 5000
    split_size_name = split_size // 1000
    for index in tqdm(range(0, len(short_llava_data), split_size)):
        start = index
        # Bound by the list actually being sliced (previously len(llava_data)).
        end = min(index + split_size, len(short_llava_data))
        chunk_json = llava_json.replace(".json", f".short.{split_size_name}k.{index // split_size}.json")
        with File.open(chunk_json, "w") as f:
            json.dump(short_llava_data[start:end], f)
        convert_to_llava_eval_format(chunk_json)


def split_gqa_idk_shuffle_json(llava_json):
    """Split a JSON list into consecutive 5000-sample chunks.

    Chunk ``k`` is saved to ``llava_json`` with ``.json`` replaced by
    ``.5k.<k>.json``. Each chunk is then passed to
    ``convert_to_llava_eval_format``. Unlike the "short" splitter, no
    filtering is applied.
    """
    chunk = 5000
    chunk_tag = f"{chunk // 1000}k"
    with File.open(llava_json, "r") as fin:
        samples = json.load(fin)

    for chunk_idx, lo in enumerate(tqdm(range(0, len(samples), chunk))):
        hi = min(lo + chunk, len(samples))
        out_path = llava_json.replace(".json", f".{chunk_tag}.{chunk_idx}.json")
        with File.open(out_path, "w") as fout:
            json.dump(samples[lo:hi], fout)
        convert_to_llava_eval_format(out_path)
    

def convert_to_llava_eval_format(llava_json):
    """Convert a LLaVA conversation JSON list into an eval JSONL file.

    For each sample, the output record is built as follows:
      * ``question_id`` comes from ``id`` and ``image`` is copied as-is.
      * ``text`` is the first turn with ``"<image>\\n"`` removed.
      * ``answer`` is the last turn.
      * ``category`` is fixed to ``"short"``.
      * All other top-level keys are copied through, and may overwrite the
        fields above.

    Records are written one per line to ``llava_json`` with ``.json``
    replaced by ``.eval.jsonl``.
    """
    with File.open(llava_json, "r") as fin:
        samples = json.load(fin)

    structural_keys = ("id", "image", "conversations")
    records = []
    for sample in tqdm(samples):
        turns = sample["conversations"]
        record = {
            "question_id": sample["id"],
            "image": sample["image"],
            "text": turns[0]["value"].replace("<image>\n", ""),
            "answer": turns[-1]["value"],
            "category": "short",
        }
        record.update({key: val for key, val in sample.items() if key not in structural_keys})
        records.append(record)

    with File.open(llava_json.replace(".json", ".eval.jsonl"), "w") as fout:
        for record in records:
            fout.write(json.dumps(record) + "\n")


def _majority_answer(answers):
    """Return the most frequent answer; ties go to the earliest-listed answer.

    Raises:
        ValueError: if ``answers`` is empty.
    """
    from collections import Counter

    if not answers:
        raise ValueError("answerable sample has no answers")
    # Counter.most_common orders equal counts by first occurrence, so unlike
    # max(set(...)) the result does not depend on string hash randomisation.
    return Counter(answers).most_common(1)[0][0]


def transform_vizwiz_to_llava_eval(vizwiz_json):
    """Convert VizWiz annotations into LLaVA eval JSONL files.

    Two files are written next to ``vizwiz_json``:
      * ``.llava_eval.jsonl``: ``text`` is the raw question.
      * ``.llava_eval.prompt.jsonl``: ``text`` is the question followed by
        the "respond with 'Unanswerable'" / short-answer instruction.

    Record fields:
      * ``question_id``: the sample's position in the input list.
      * ``answer``: ``"unanswerable"`` for unanswerable samples, otherwise the
        most frequent annotator answer (see ``_majority_answer``).
      * ``category``: the sample's ``answer_type``.
    """
    with File.open(vizwiz_json, "r") as f:
        vizwiz_data = json.load(f)
    vizwiz_prompt = "\nWhen the provided information is insufficient, respond with 'Unanswerable'.\nAnswer the question using a single word or phrase."

    # Build records once; the prompt variant differs only in "text".
    output = []
    for id_, d in tqdm(enumerate(vizwiz_data)):
        answers = [ans["answer"] for ans in d["answers"]]
        answerable = d["answerable"]
        answer = _majority_answer(answers) if answerable else "unanswerable"
        output.append({
            "question_id": id_,
            "image": d["image"],
            "text": d["question"],
            "answer": answer,
            "category": d["answer_type"],
            "answers": answers,
            "answerable": answerable,
        })

    with File.open(vizwiz_json.replace(".json", ".llava_eval.jsonl"), "w") as f:
        for d in output:
            f.write(json.dumps(d) + "\n")

    with File.open(vizwiz_json.replace(".json", ".llava_eval.prompt.jsonl"), "w") as f:
        for d in output:
            f.write(json.dumps({**d, "text": d["text"] + vizwiz_prompt}) + "\n")



def convert_mmhal_to_llava_eval(json_file):
    """Convert MMHal-Bench style JSON into a LLaVA eval JSONL file.

    Record fields:
      * ``question_id``: the item's position in the input list.
      * ``image``: the last ``/``-separated component of ``image_src``.
      * ``text``: the ``question``.
      * ``answer``: the ``gt_answer``.
      * Every key other than ``question_id``/``image_src``/``question``/
        ``gt_answer`` is copied through.

    The output goes to ``json_file`` with ``.json`` replaced by
    ``.llava_eval.jsonl``.
    """
    with File.open(json_file, "r") as fin:
        mmhal_items = json.load(fin)

    consumed = {"question_id", "image_src", "question", "gt_answer"}
    records = []
    for position, item in tqdm(enumerate(mmhal_items)):
        record = dict(
            question_id=position,
            image=item["image_src"].rpartition("/")[2],
            text=item["question"],
            answer=item["gt_answer"],
        )
        record.update((key, val) for key, val in item.items() if key not in consumed)
        records.append(record)

    with File.open(json_file.replace(".json", ".llava_eval.jsonl"), "w") as fout:
        for record in records:
            fout.write(json.dumps(record) + "\n")


def convert_vqa_k_test_to_llava_eval(jsonl_file, image_folder="coco/train2017"):
    """Convert a VQA-style JSONL file into LLaVA eval JSONL files.

    Questions with an empty ``labels`` field are dropped. For each remaining
    question:
      * The image path is ``<image_folder>/<image_id:012d>.jpg``. If that file
        is missing under ``blob_dir``, ``train2017`` is replaced with
        ``val2017``.
      * ``answer`` is the key of ``labels`` with the largest value.
      * Every field except ``question_id``/``image_id``/``question``/
        ``gt_answer`` is copied through.

    Four files are written next to ``jsonl_file``:
      * ``.llava_eval.jsonl`` and ``_5k.llava_eval.jsonl``: raw question text.
      * ``.llava_eval.prompt.jsonl`` and ``_5k.llava_eval.prompt.jsonl``: the
        question followed by the short-answer instruction.

    Both ``_5k`` files use the same random subset of 5000 question ids.

    Raises:
        AssertionError: if an image exists in neither the train2017 nor the
            val2017 path.
        ValueError: if fewer than 5000 questions remain after filtering.
    """
    import random

    blob_dir = "<DATA_FOLDER>"
    vqa_k_test = [json.loads(line.strip()) for line in File.open(jsonl_file, "r")]
    vqa_prompt = "\nAnswer the question using a single word or phrase."
    # Fields consumed into the fixed output keys; everything else is copied.
    consumed = {"question_id", "image_id", "question", "gt_answer"}

    output = []
    all_qids = set()
    for d in tqdm(vqa_k_test):
        if not d["labels"]:
            continue
        image_file = f"{image_folder}/{d['image_id']:012d}.jpg"
        if not File.isfile(os.path.join(blob_dir, image_file)):
            image_file = image_file.replace("train2017", "val2017")
            assert File.isfile(os.path.join(blob_dir, image_file))
        qid = d["question_id"]
        all_qids.add(qid)
        # answer is the key with the largest value in d["labels"]
        answer = max(d["labels"], key=d["labels"].get)
        output_dict = {"question_id": qid, "image": image_file, "text": d["question"], "answer": answer}
        output_dict.update({k: v for k, v in d.items() if k not in consumed})
        output.append(output_dict)

    def write_jsonl(suffix, records, subset=None):
        # Write `records` as JSONL, keeping only ids in `subset` when given.
        with File.open(jsonl_file.replace(".jsonl", suffix), "w") as f:
            for r in records:
                if subset is None or r["question_id"] in subset:
                    f.write(json.dumps(r) + "\n")

    write_jsonl(".llava_eval.jsonl", output)

    # random.sample() no longer accepts a set (TypeError on Python 3.11+).
    # Sorting first also makes the draw reproducible under a fixed seed.
    # Keep the result as a set for O(1) membership tests.
    sampled_5k = set(random.sample(sorted(all_qids), 5000))
    write_jsonl("_5k.llava_eval.jsonl", output, sampled_5k)

    # New dicts instead of mutating `output` in place.
    prompt_output = [{**d, "text": d["text"] + vqa_prompt} for d in output]
    print(f"Total number of questions: {len(output)}")
    write_jsonl(".llava_eval.prompt.jsonl", prompt_output)
    write_jsonl("_5k.llava_eval.prompt.jsonl", prompt_output, sampled_5k)


def convert_vqa_tdiuc_to_llava_eval(jsonl_file, image_folder="coco/train2017"):
    """Convert a JSONL of VQA questions into LLaVA eval JSONL files.

    Every question's ``answer`` is set to ``"unanswerable"``. For each
    question:
      * The image path is ``<image_folder>/<image_id:012d>.jpg``. If that file
        is missing under ``blob_dir``, ``train2017`` is replaced with
        ``val2017``.
      * Every field except ``question_id``/``image_id``/``question``/
        ``gt_answer`` is copied through.

    Four files are written next to ``jsonl_file``:
      * ``.llava_eval.jsonl`` and ``_5k.llava_eval.jsonl``: raw question text.
      * ``.llava_eval.prompt.jsonl`` and ``_5k.llava_eval.prompt.jsonl``: the
        question followed by an instruction to answer 'Unanswerable' when
        information is insufficient.

    Both ``_5k`` files use the same random subset of 5000 question ids.

    Raises:
        AssertionError: if an image exists in neither the train2017 nor the
            val2017 path.
        ValueError: if there are fewer than 5000 questions.
    """
    import random

    blob_dir = "<DATA_FOLDER>"
    vqa_k_test = [json.loads(line.strip()) for line in File.open(jsonl_file, "r")]
    vqa_prompt = "\nWhen the provided information is insufficient, respond with 'Unanswerable'.\nAnswer the question using a single word or phrase."
    # Fields consumed into the fixed output keys; everything else is copied.
    consumed = {"question_id", "image_id", "question", "gt_answer"}

    output = []
    all_qids = set()
    for d in tqdm(vqa_k_test):
        image_file = f"{image_folder}/{d['image_id']:012d}.jpg"
        if not File.isfile(os.path.join(blob_dir, image_file)):
            image_file = image_file.replace("train2017", "val2017")
            assert File.isfile(os.path.join(blob_dir, image_file))
        qid = d["question_id"]
        all_qids.add(qid)
        # Every question in this set is labelled unanswerable.
        output_dict = {"question_id": qid, "image": image_file, "text": d["question"], "answer": "unanswerable"}
        output_dict.update({k: v for k, v in d.items() if k not in consumed})
        output.append(output_dict)

    def write_jsonl(suffix, records, subset=None):
        # Write `records` as JSONL, keeping only ids in `subset` when given.
        with File.open(jsonl_file.replace(".jsonl", suffix), "w") as f:
            for r in records:
                if subset is None or r["question_id"] in subset:
                    f.write(json.dumps(r) + "\n")

    write_jsonl(".llava_eval.jsonl", output)

    # random.sample() no longer accepts a set (TypeError on Python 3.11+).
    # Sorting first also makes the draw reproducible under a fixed seed.
    # Keep the result as a set for O(1) membership tests.
    sampled_5k = set(random.sample(sorted(all_qids), 5000))
    write_jsonl("_5k.llava_eval.jsonl", output, sampled_5k)

    # New dicts instead of mutating `output` in place.
    prompt_output = [{**d, "text": d["text"] + vqa_prompt} for d in output]
    print(f"Total number of questions: {len(output)}")
    write_jsonl(".llava_eval.prompt.jsonl", prompt_output)
    write_jsonl("_5k.llava_eval.prompt.jsonl", prompt_output, sampled_5k)


def convert_unk_vqa_public_to_llava_eval(jsonl_file, image_folder="images-val"):
    """Convert the public UNK-VQA JSONL into LLaVA eval JSONL files.

    Record fields:
      * ``image``: ``<image_folder>/<image_name>``, which must exist under the
        UNK-VQA blob directory (asserted).
      * ``question_id``, ``text`` (the question) and ``answer``: taken from
        the row.
      * Every field except ``question_id``/``image_name``/``question``/
        ``answer`` is copied through.

    Two files are written next to ``jsonl_file``: ``.llava_eval.jsonl`` (raw
    question) and ``.llava_eval.prompt.jsonl`` (question plus the
    short-answer instruction).
    """
    blob_dir = "<DATA_FOLDER>/UNK-VQA"
    source_rows = [json.loads(line.strip()) for line in File.open(jsonl_file, "r")]
    vqa_prompt = "\nAnswer the question using a single word or phrase."
    consumed = ("question_id", "image_name", "question", "answer")

    records = []
    for _, row in tqdm(enumerate(source_rows)):
        image_file = f"{image_folder}/{row['image_name']}"
        assert File.isfile(os.path.join(blob_dir, image_file))
        record = {
            "question_id": row["question_id"],
            "image": image_file,
            "text": row["question"],
            "answer": row["answer"],
        }
        record.update({key: val for key, val in row.items() if key not in consumed})
        records.append(record)

    with File.open(jsonl_file.replace(".jsonl", ".llava_eval.jsonl"), "w") as fout:
        for record in records:
            fout.write(json.dumps(record) + "\n")

    prompted = []
    for _, record in tqdm(enumerate(records)):
        record["text"] += vqa_prompt
        prompted.append(record)
    print(f"Total number of questions: {len(records)}")
    with File.open(jsonl_file.replace(".jsonl", ".llava_eval.prompt.jsonl"), "w") as fout:
        for record in prompted:
            fout.write(json.dumps(record) + "\n")


def sample_eval_subset(gt_jsonl_file, orginal_gt_jsonl_file, num_samples=100):
    """Write a random ``num_samples`` subset of ``gt_jsonl_file``.

    The loaded rows are shuffled in place with ``random.shuffle`` and the
    first ``num_samples`` are kept. A kept row has its ``answer`` replaced by
    the original row's answer when both of these hold:
      * its counterpart in ``orginal_gt_jsonl_file`` (matched by
        ``question_id``) has ``category == "unk"``;
      * the kept row itself has ``answerable == 0``.

    The output path is ``gt_jsonl_file`` with ``.jsonl`` replaced by
    ``.subset_<num_samples>.jsonl`` and every ``_idk`` removed.

    Raises:
        KeyError: if a sampled ``question_id`` is missing from the original file.
    """
    import random

    verified_rows = [json.loads(el) for el in File.open(gt_jsonl_file, 'r')]
    original_rows = [json.loads(el) for el in File.open(orginal_gt_jsonl_file, 'r')]
    original_by_qid = {row["question_id"]: row for row in original_rows}

    random.shuffle(verified_rows)
    subset = verified_rows[:num_samples]
    subset_path = gt_jsonl_file.replace(".jsonl", f".subset_{num_samples}.jsonl").replace("_idk", "")
    with File.open(subset_path, "w") as fout:
        for row in subset:
            source = original_by_qid[row["question_id"]]
            restore = source["category"] == "unk" and row["answerable"] == 0
            if restore:
                row["answer"] = source["answer"]
            fout.write(json.dumps(row) + '\n')


def revert_human_verified_to_orig(gt_jsonl_file, orginal_gt_jsonl_file):
    """Restore original answers for unanswerable "unk" rows and drop ``_idk``.

    Every row of ``gt_jsonl_file`` is rewritten, in order, to the same path
    with ``"_idk"`` removed. A row's ``answer`` is replaced by the original
    row's answer when both of these hold:
      * the matching row in ``orginal_gt_jsonl_file`` (by ``question_id``)
        has ``category == "unk"``;
      * the row itself has ``answerable == 0``.

    Both inputs are fully read before the output is opened.

    Raises:
        KeyError: if a ``question_id`` is missing from the original file.
    """
    verified_rows = [json.loads(el) for el in File.open(gt_jsonl_file, 'r')]
    original_rows = [json.loads(el) for el in File.open(orginal_gt_jsonl_file, 'r')]
    original_by_qid = {row["question_id"]: row for row in original_rows}

    with File.open(gt_jsonl_file.replace("_idk", ""), "w") as fout:
        for row in verified_rows:
            source = original_by_qid[row["question_id"]]
            if source["category"] == "unk" and row["answerable"] == 0:
                row["answer"] = source["answer"]
            fout.write(json.dumps(row) + '\n')



def perturb_answerable_images(validated_gt_jsonl, perturbed_jsonl):
    """Swap in images from ``perturbed_jsonl`` for answerable, non-"unk" rows.

    Each row of ``validated_gt_jsonl`` takes the ``image`` of the
    ``perturbed_jsonl`` row with the same ``question_id`` when both of these
    hold:
      * the perturbed row's ``category`` is not ``"unk"``;
      * the validated row has ``answerable == 1``.

    All rows are written, in order, to ``validated_gt_jsonl`` with
    ``.jsonl`` replaced by ``.perturb_answerable.jsonl``.

    Raises:
        KeyError: if a ``question_id`` is missing from ``perturbed_jsonl``.
    """
    validated_rows = [json.loads(el) for el in File.open(validated_gt_jsonl, 'r')]
    perturbed_rows = [json.loads(el) for el in File.open(perturbed_jsonl, 'r')]
    perturbed_by_qid = {row["question_id"]: row for row in perturbed_rows}

    out_path = validated_gt_jsonl.replace(".jsonl", ".perturb_answerable.jsonl")
    with File.open(out_path, "w") as fout:
        for row in validated_rows:
            counterpart = perturbed_by_qid[row["question_id"]]
            if counterpart["category"] != "unk" and row["answerable"] == 1:
                row["image"] = counterpart["image"]
            fout.write(json.dumps(row) + '\n')


def convert_llava_to_dpo_format(llava_format_file, save_path, debug=False):
    """Build a DPO preference dataset from paired unknown/known eval samples.

    Samples are read from the JSONL ``llava_format_file`` and keyed by
    ``question_id``. Each sample carries either ``category``
    (``"know"``/``"unk"``) or ``answerable`` (1/0). Each unknown sample is
    paired with a known counterpart whose id is derived from its own id:
    ``key.split("-")[0]`` or ``key.split("_")[0] + "_clean"``.

    Each pair yields two rows sharing the unknown sample's ``text`` as prompt:
      * image of the known sample: chosen = known answer, rejected = unknown
        answer.
      * image of the unknown sample: chosen = unknown answer, rejected = known
        answer.

    The rows are saved with ``Dataset.save_to_disk`` to ``save_path``
    (``save_path + "_debug"`` in debug mode, which keeps only the first 102
    pairs). Every entry in that directory is then re-written through
    ``File``.

    Raises:
        AssertionError: on an unexpected category or a missing counterpart.
        ValueError: if a sample has neither ``category`` nor ``answerable``.
    """
    from datasets import Dataset
    assert File.isfile(llava_format_file), f"File not found: {llava_format_file}"
    llava_data = [json.loads(l) for l in File.open(llava_format_file, "r")]

    qid2q = {q["question_id"]: q for q in llava_data}
    if debug:
        save_path += "_debug"

    # (unknown_qid, known_qid) pairs
    qid_pairs = []
    for key in tqdm(sorted(qid2q)):
        q = qid2q[key]
        if "category" in q:
            assert q["category"] in ["know", "unk"], f"Unknown category: {q}"
            if q["category"] == "unk":
                if "." in key:
                    known_key = key.split("_")[0] + "_clean"
                else:
                    known_key = key.split("-")[0]
                assert known_key in qid2q, f"Known key not found: {known_key}"
                known = qid2q[known_key]
                assert known["category"] in ["know", "unk"], f"known category: {known}"
                if known["category"] == "know":
                    qid_pairs.append((key, known_key))
                else:
                    print(f"Unk key: {key}, known_key: {known_key}")
        elif "answerable" in q:
            if q["answerable"] == 0:
                known_key = key.split("-")[0]
                if known_key not in qid2q:
                    known_key = key.split("_")[0] + "_clean"
                assert known_key in qid2q, f"Known key not found: {known_key}"
                # Pair only with an answerable counterpart. The previous check
                # re-tested q["answerable"] == 1 inside this `== 0` branch, so
                # it was always False and no pairs were ever produced here.
                if qid2q[known_key].get("answerable") == 1:
                    qid_pairs.append((key, known_key))
        else:
            raise ValueError(f"Unknown category: {q}")

    print(len(qid_pairs))

    formatted_data = {"train": []}
    for idd_count, (idk_qid, k_qid) in enumerate(qid_pairs):
        q = qid2q[idk_qid]
        q_orig = qid2q[k_qid]
        question = q["text"]
        a_pt = q["answer"]
        a_orig = q_orig["answer"]
        # Known sample's image: prefer the known answer.
        formatted_data["train"].append({
            "id": idd_count * 2,
            "prompt": question,
            "img_path": q_orig["image"],
            "chosen": a_orig,
            "rejected": a_pt,
        })
        # Unknown sample's image: prefer the unknown sample's answer.
        formatted_data["train"].append({
            "id": idd_count * 2 + 1,
            "prompt": question,
            "img_path": q["image"],
            "chosen": a_pt,
            "rejected": a_orig,
        })
        if debug and idd_count > 100:
            break

    # dumping train data
    my_dataset = Dataset.from_list(formatted_data["train"])
    my_dataset.save_to_disk(save_path)

    # Re-write each saved entry through File (presumably uploads it to the
    # remote store; async_upload semantics per azfuse -- confirm).
    to_upload = os.listdir(save_path)
    with File.async_upload(enabled=True):
        for f in to_upload:
            filepath = os.path.join(save_path, f)
            with open(filepath, "rb") as local_fp:
                content = local_fp.read()
            with File.open(filepath, "wb") as fp:
                fp.write(content)

    print("Done DPO dumping dataset")
    return


def convert_silkie_to_hadpo_format():
    """Turn VLFeedback (Silkie) ratings into HA-DPO preference pairs and merge them.

    For every unordered pair of completions of a sample, each completion is
    scored by the mean of its per-aspect ``Rating`` values. The higher-scoring
    completion becomes "chosen" and the other "rejected". A pair is skipped
    when:
      * the scores are not strictly ordered (ties/NaN); or
      * a rating cannot be converted to float.

    Prompts get a leading ``"<image>\\n"`` when they lack ``"<image>"``. Three
    JSON files are written under ``<DATA_FOLDER>/hadpo-data/hadpo/llava-v1.5``:
      * ``silkie.json``: Silkie pairs only.
      * ``silkie+ours_image_based.json``: plus everything in
        ``ours_image_based.json``.
      * ``silkie+unk.json``: plus the ``ours_image_based.json`` entries whose
        ``image`` does not contain ``"gqa"``.
    """
    from itertools import combinations

    dataset_path = "<DATA_FOLDER>/MMInstruction/vlfeedback_80k.jsonl"
    image_folder = "silkie/merged_images"
    hadpo_dir = "<DATA_FOLDER>/hadpo-data/hadpo/llava-v1.5"
    feedback_data = [json.loads(line) for line in File.open(dataset_path)]

    def as_llava_prompt(raw):
        text = raw.strip()
        return text if "<image>" in text else "<image>\n" + text

    def mean_rating(annotations):
        return np.mean([float(annotations[aspect]["Rating"]) for aspect in annotations])

    pairs = []
    for _, sample in tqdm(enumerate(feedback_data)):
        completions = sample["completions"]
        prompt = as_llava_prompt(sample["prompt"])
        sample_id = sample["id"]
        image_path = os.path.join(image_folder, f"{sample_id}.jpg")
        for i, j in combinations(range(len(completions)), 2):
            annos_i = completions[i]["annotations"]
            annos_j = completions[j]["annotations"]
            try:
                score_i = mean_rating(annos_i)
                score_j = mean_rating(annos_j)
            except ValueError:
                continue

            if score_i > score_j:
                winner, loser = i, j
            elif score_j > score_i:
                winner, loser = j, i
            else:
                continue

            pairs.append({
                "id": f"silkie_{sample_id}_{i}_{j}",
                "image": image_path,
                "chosen_conversations": [
                    {"from": "human", "value": prompt},
                    {"from": "gpt", "value": completions[winner]["response"]},
                ],
                "reject_conversations": [
                    {"from": "human", "value": prompt},
                    {"from": "gpt", "value": completions[loser]["response"]},
                ],
            })

    with File.open(f"{hadpo_dir}/silkie.json", "w") as fout:
        print(len(pairs))
        json.dump(pairs, fout)

    with File.open(f"{hadpo_dir}/ours_image_based.json", "r") as fin:
        ours = json.load(fin)

    with_ours = pairs + ours
    with File.open(f"{hadpo_dir}/silkie+ours_image_based.json", "w") as fout:
        print(len(with_ours))
        json.dump(with_ours, fout)

    with_unk = pairs + [item for item in ours if "gqa" not in item["image"]]
    with File.open(f"{hadpo_dir}/silkie+unk.json", "w") as fout:
        print(len(with_unk))
        json.dump(with_unk, fout)


def convert_lrv_to_llava_format():
    """Convert LRV-Instruction annotations into LLaVA conversation JSON.

    Caption-based QA items come from the four ``filter_cap*.json`` files, with
    images under ``vg/VG_100K/<image_id>.jpg``; they are written to
    ``no_chart.json``. Chart items from ``chart_release_update.json`` (images
    under ``LRV-instruction/chart_image/``) are then appended to the same
    list, which is written to ``with_chart.json``. ``with_chart.json``
    therefore contains every ``no_chart.json`` sample plus the chart samples.
    """
    lrv_root = "<DATA_FOLDER>/LRV-instruction"
    caption_names = ("filter_cap1.json", "filter_cap.json", "filter_cap_more.json", "filter_cap_more1.json")
    caption_files = [f"{lrv_root}/{name}" for name in caption_names]
    chart_file = f"{lrv_root}/chart_release_update.json"

    def to_llava(idx, item, image_path):
        return {
            "id": f"{idx}_{item['image_id']}",
            "image": image_path,
            "conversations": [
                {"from": "human", "value": "<image>\n" + item["question"]},
                {"from": "gpt", "value": item["answer"]},
            ],
        }

    caption_items = []
    for path in caption_files:
        caption_items += json.load(File.open(path))

    records = []
    for idx, item in tqdm(enumerate(caption_items), total=len(caption_items)):
        image_path = os.path.join("vg/VG_100K", f"{item['image_id']}.jpg")
        records.append(to_llava(idx, item, image_path))
    print(len(records))
    with File.open(f"{lrv_root}/no_chart.json", "w") as fout:
        json.dump(records, fout)

    chart_items = json.load(File.open(chart_file))
    for idx, item in tqdm(enumerate(chart_items), total=len(chart_items)):
        image_path = os.path.join("LRV-instruction/chart_image/", f"{item['image_id']}")
        records.append(to_llava(idx, item, image_path))
    print(len(records))
    with File.open(f"{lrv_root}/with_chart.json", "w") as fout:
        json.dump(records, fout)


if __name__ == "__main__":
    # CLI entry point. Fire() is called without a component; per python-fire's
    # docs (confirm) it then exposes the calling module's global names, i.e.
    # the converter functions above, as subcommands. That makes the imported
    # name itself part of the command set, so keep this import as written.
    from fire import Fire
    Fire()
